[generate] fix breaking change for patch#29976
Merged
ArthurZucker merged 12 commits intomainfrom Apr 2, 2024
Merged
Conversation
|
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update. |
gante
reviewed
Apr 1, 2024
tests/generation/test_utils.py
Outdated
Comment on lines
721
to
732
| input_embeds = model.get_input_embeddings()(input_ids) | ||
| beam_kwargs.update({"inputs_embeds": input_embeds}) | ||
| output_generate2 = self._beam_sample_generate( | ||
| model=model, | ||
| input_ids=None, | ||
| attention_mask=attention_mask, | ||
| max_length=max_length, | ||
| beam_kwargs=beam_kwargs, | ||
| logits_warper_kwargs=logits_warper_kwargs, | ||
| ) | ||
|
|
||
| torch.testing.assert_close(output_generate[:, input_embeds.shape[1] :], output_generate2) |
Contributor
There was a problem hiding this comment.
This can't be tested in the mixin -- the vast majority of the models don't support passing inputs_embeds to generate, they need would some changes in prepare_inputs_for_generate
Collaborator
Author
There was a problem hiding this comment.
Alright I'll check the signature
Collaborator
Author
|
Failing test is unrelated |
ArthurZucker
added a commit
that referenced
this pull request
Apr 2, 2024
* fix bug and add tests * nit * otherway to get the cur len instead of attention mask * more places where this might have been broken * nit * oups * inputs_embeds vs input_embeds * test generated outptus * style * nit * fix * skip failing biogpt
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
What does this PR do?
A bug was introduced by #29467 pretty much unrelated to cache positions.
This fixes #29968
cc @gante and @zucchini-nlp. The testing suite is missing this particular test for all generation strategies